package mcfall.raytracer.tests;

import mcfall.math.ColumnVector;
import mcfall.math.IncompatibleMatrixException;

import junit.framework.TestCase;

/**
 * Unit tests for {@link mcfall.math.Matrix}.
 * <p>
 * This test class shares its simple name with the class under test, so the
 * production type is always referred to by its fully qualified name.
 * {@code assertAlmostEquals} is not defined here; it is inherited from the
 * project's {@code mcfall.raytracer.tests.TestCase} base class.
 */
public class Matrix extends mcfall.raytracer.tests.TestCase {
	/** 3x4 fixture; its values and 1-based indexing are (re)established by {@link #setUp()}. */
	mcfall.math.Matrix matrix;
	/** 4x3 operand built from {@link #multiplier}. */
	mcfall.math.Matrix multiplierMatrix;
	
	/** Contents of {@link #matrix}: 3 rows x 4 columns. */
	static double[][] initialValues = {
			{0, 1, 2, 3},
			{4, 5, 6, 7},
			{8, 9, 10, 11}
	};
	
	/** Contents of {@link #multiplierMatrix}: 4 rows x 3 columns. */
	static double[][] multiplier = {
		{-1, 1, -2},
		{0, 1, 0},
		{1, 2, 3},
		{3, 2, 1}
	};
	
	/** Expected {@code matrix.premultiply(multiplierMatrix)}, i.e. multiplier x initialValues (4x4). */
	static double[][] preMultiplyResult = {
		{-12, -14, -16, -18},
		{4, 5, 6, 7},
		{32, 38, 44, 50},
		{16, 22, 28, 34}
	};
	
	/** Expected {@code matrix.postmultiply(multiplierMatrix)}, i.e. initialValues x multiplier (3x3). */
	static double[][] postMultiplyResult = {
		{11, 11, 9},
		{23, 35, 17},
		{35, 59, 25}
	};
	
	/** Number of rows in {@link #initialValues}. */
	private final int numRows;
	/** Number of columns in {@link #initialValues}. */
	private final int numColumns;
	
	/**
	 * Builds the fixture matrices. {@link #matrix} is only sized here; its
	 * values are filled in by {@link #setUp()}.
	 */
	public Matrix () {
		numRows = initialValues.length;
		numColumns = initialValues[0].length;
		matrix = new mcfall.math.Matrix (numRows, numColumns);
		multiplierMatrix = new mcfall.math.Matrix (multiplier.length, multiplier[0].length, multiplier);
	}
	
	/**
	 * Resets {@link #matrix} to 1-based row/column indexing and refills it with
	 * {@link #initialValues}. Individual tests change both the indexing and the values.
	 */
	public void setUp () {
		matrix.setFirstColumnIndex(1);
		matrix.setFirstRowIndex(1);
		for (int r = 1; r <= numRows; r++) {
			for (int c = 1; c <= numColumns; c++) {
				matrix.setValueAt(r, c, initialValues[r-1][c-1]);
			}
		}
	}
	
	/**
	 * Tests {@code mcfall.math.Matrix(int, int)}: a new 3x4 matrix uses
	 * 1-based indices in both dimensions and every element is 0.
	 */
	public void testTwoArgumentConstructor () {
		mcfall.math.Matrix zeroMatrix = new mcfall.math.Matrix (3,4);
		assertEquals (1, zeroMatrix.getFirstColumnIndex());
		assertEquals (4, zeroMatrix.getLastColumnIndex());
		assertEquals (1, zeroMatrix.getFirstRowIndex());
		assertEquals (3, zeroMatrix.getLastRowIndex());
		
		for (int row = 1; row <= 3; row++) {
			for (int col = 1; col <= 4; col++) {
				assertEquals (0.0, zeroMatrix.getValueAt(row, col));
			}
		}
	}
	
	/**
	 * Tests {@code mcfall.math.Matrix(int, int, double[][])}: the matrix has the
	 * requested size and element (r+1, c+1) equals {@code initialValues[r][c]}.
	 * A local is used so the shared {@link #matrix} fixture is left untouched.
	 */
	public void testThreeArgumentConstructor () {
		mcfall.math.Matrix constructed = new mcfall.math.Matrix (numRows, numColumns, initialValues);
		assertEquals (numRows, constructed.getNumberOfRows());
		assertEquals (numColumns, constructed.getNumberOfColumns());
		for (int r = 0; r < numRows; r++) {
			for (int c = 0; c < numColumns; c++) {
				assertEquals (initialValues[r][c], constructed.getValueAt(r+1, c+1));
			}
		}
	}
	
	/**
	 * Tests {@code mcfall.math.Matrix.equals(Object)}. Verifies that
	 * <ul>
	 * <li>comparing against a non-matrix object returns false;</li>
	 * <li>matrices with a different number of rows, or of columns, are not equal;</li>
	 * <li>a distinct matrix holding the same values is equal;</li>
	 * <li>a same-size matrix differing only in its upper-left element is not equal
	 *     (other element positions are not perturbed);</li>
	 * <li>a matrix is equal to itself.</li>
	 * </ul>
	 */
	public void testEquals () {
		Object o = new Object ();
		assertFalse ("Matrix tested against non-matrix returns false", matrix.equals(o));
		
		mcfall.math.Matrix testMatrix = new mcfall.math.Matrix (matrix.getNumberOfRows()+1, matrix.getNumberOfColumns());
		assertFalse ("Matrix with different number of rows not equal", matrix.equals(testMatrix));
		
		testMatrix = new mcfall.math.Matrix (matrix.getNumberOfRows(), matrix.getNumberOfColumns()+1);
		assertFalse ("Matrix with different number of columns not equal", matrix.equals(testMatrix));
		
		testMatrix = new mcfall.math.Matrix (matrix.getNumberOfRows(), matrix.getNumberOfColumns(), initialValues);
		assertTrue ("Two distinct equivalent matrices are equal", matrix.equals(testMatrix));
		
		int firstRow = testMatrix.getFirstRowIndex();
		int firstColumn = testMatrix.getFirstColumnIndex();
		testMatrix.setValueAt(firstRow, firstColumn, 1+testMatrix.getValueAt(firstRow, firstColumn));
		assertFalse ("Same size matrix with different upper left not equals", matrix.equals(testMatrix));
		
		assertTrue ("Matrix compared to itself is equal", matrix.equals(matrix));
	}
	
	/**
	 * Tests {@code mcfall.math.Matrix.getValueAt(int, int)} at the first and last
	 * elements under 1-based, 0-based, and negative (row -2, column -1) first indices.
	 */
	public void testGetValueAt() {
		//  First test the boundary conditions using 1 based indexing
		assertEquals (initialValues[0][0], matrix.getValueAt(1, 1));		
		assertEquals (initialValues[numRows-1][numColumns-1], matrix.getValueAt(numRows, numColumns));
		
		//  Now change to 0-based indexing
		matrix.setFirstColumnIndex(0);
		matrix.setFirstRowIndex(0);
		assertEquals (initialValues[0][0], matrix.getValueAt(0, 0));
		assertEquals (initialValues[numRows-1][numColumns-1], matrix.getValueAt(numRows-1, numColumns-1));
		
		//  Finally, change to a negative number
		int firstColumnIndex = -1;
		int firstRowIndex = -2;
		
		matrix.setFirstColumnIndex(firstColumnIndex);
		matrix.setFirstRowIndex(firstRowIndex);
		assertEquals (initialValues[0][0], matrix.getValueAt(firstRowIndex, firstColumnIndex));
		assertEquals (initialValues[numRows-1][numColumns-1], matrix.getValueAt(firstRowIndex+numRows-1, firstColumnIndex+numColumns-1));
	}

	/**
	 * Tests {@code mcfall.math.Matrix.setValueAt(int, int, double)}: it returns the
	 * previous value and stores the new one, at the first and last elements, under
	 * 1-based, 0-based and negative first indices. The 0-based step writes the
	 * original values back over the ones written in the 1-based step, so the
	 * "old" values it expects are {@code newValue}.
	 */
	public void testSetValueAt() {
		//  Set the boundary values using 1 based indexing
		double newValue = 10.0;
		
		double oldValue = matrix.setValueAt (1, 1, newValue);
		assertEquals (initialValues[0][0], oldValue);
		assertEquals (newValue, matrix.getValueAt (1, 1));
		
		oldValue = matrix.setValueAt(numRows, numColumns, newValue);
		assertEquals ("1-based indexing: Correct old value returned from setValueAt", initialValues[numRows-1][numColumns-1], oldValue);
		assertEquals ("1-based indexing: New value correctly returned after setValueAt", newValue, matrix.getValueAt(numRows, numColumns));
		
		//  Set the boundary values using 0 based indexing
		matrix.setFirstColumnIndex(0);
		matrix.setFirstRowIndex(0);
		
		oldValue = matrix.setValueAt(0, 0, initialValues[0][0]);
		assertEquals ("0-based indexing: Correct old value returned from setValueAt", newValue, oldValue);
		assertEquals ("0-based indexing: New value correctly returned after setValueAt", initialValues[0][0], matrix.getValueAt(0, 0));
		
		oldValue = matrix.setValueAt(numRows-1, numColumns-1, initialValues[numRows-1][numColumns-1]);
		assertEquals ("0-based indexing: Correct old value returned from setValueAt", newValue, oldValue);
		assertEquals ("0-based indexing: New value correctly returned after setValueAt", initialValues[numRows-1][numColumns-1], matrix.getValueAt(numRows-1, numColumns-1));
		
		//  Set the boundary values using negative number based indexing
		matrix.setFirstColumnIndex(-1);
		matrix.setFirstRowIndex(-2);
		oldValue = matrix.setValueAt (matrix.getFirstRowIndex(), matrix.getFirstColumnIndex(), newValue);
		assertEquals ("-2 based indexing: Correct old value returned from setValueAt", initialValues[0][0], oldValue);
		assertEquals ("-2 based indexing: New value correctly returned after setValueAt", newValue, matrix.getValueAt(matrix.getFirstRowIndex(), matrix.getFirstColumnIndex()));
		
		oldValue = matrix.setValueAt(matrix.getLastRowIndex(), matrix.getLastColumnIndex(), newValue);
		assertEquals ("-2 based indexing: Correct old value returned from setValueAt", initialValues[numRows-1][numColumns-1], oldValue);
		assertEquals ("-2 based indexing: New value correctly returned after setValueAt", newValue, matrix.getValueAt(matrix.getLastRowIndex(), matrix.getLastColumnIndex()));
	}

	/**
	 * Tests {@code mcfall.math.Matrix.transpose()}: the first row/column indices
	 * swap, the dimensions swap, and element (r, c) of the original equals
	 * element (c, r) of the transpose.
	 */
	public void testTranspose() {
		mcfall.math.Matrix transposeMatrix = matrix.transpose();
		
		assertEquals (matrix.getFirstColumnIndex(), transposeMatrix.getFirstRowIndex());
		assertEquals (matrix.getFirstRowIndex(), transposeMatrix.getFirstColumnIndex());
		
		//  Switch both to 0-based indexing so the loops below can index directly
		transposeMatrix.setFirstColumnIndex(0);
		transposeMatrix.setFirstRowIndex(0);
		
		matrix.setFirstColumnIndex(0);
		matrix.setFirstRowIndex(0);
		
		assertEquals (numColumns, transposeMatrix.getNumberOfRows());
		assertEquals (numRows, transposeMatrix.getNumberOfColumns());
		for (int row = 0; row < numRows; row++) {
			for (int col = 0; col < numColumns; col++) {
				assertEquals (matrix.getValueAt(row, col), transposeMatrix.getValueAt(col, row));
			}
		}
	}

	/**
	 * Tests {@code mcfall.math.Matrix.premultiply(Matrix)}.
	 * {@code matrix.premultiply(multiplierMatrix)} must have multiplierMatrix's
	 * row count and matrix's column count. It must keep matrix's first indices
	 * and equal {@link #preMultiplyResult}. An operand whose column count
	 * differs from matrix's row count must raise {@link IncompatibleMatrixException}.
	 */
	public void testPremultiply() {		
		try {
			mcfall.math.Matrix result = matrix.premultiply(multiplierMatrix);
			assertEquals (multiplierMatrix.getNumberOfRows(), result.getNumberOfRows());
			assertEquals (matrix.getNumberOfColumns(), result.getNumberOfColumns());
			assertEquals (matrix.getFirstColumnIndex(), result.getFirstColumnIndex());
			assertEquals (matrix.getFirstRowIndex(), result.getFirstRowIndex());
			
			result.setFirstColumnIndex(0);
			result.setFirstRowIndex(0);
			
			for (int r = 0; r < result.getNumberOfRows(); r++) {
				for (int c = 0; c < result.getNumberOfColumns(); c++) {
					assertEquals (preMultiplyResult[r][c], result.getValueAt(r,c));
				}
			}
		}
		catch (IncompatibleMatrixException invalidException) {
			fail ("Incompatible matrix exception thrown when it shouldn't have been: " + invalidException.getMessage());
		}
		
		//  Now ensure that an incompatible matrix exception is thrown when it should be.
		//  For X.premultiply-style X * matrix, X's column count must equal matrix's row
		//  count; choosing rows+1 columns makes the operand incompatible for any fixture shape.
		mcfall.math.Matrix incompatible = new mcfall.math.Matrix (matrix.getNumberOfRows(), matrix.getNumberOfRows()+1);
		try {
			matrix.premultiply(incompatible);
			fail ("Did not throw incompatible matrix exception for an illegal matrix pre-multiplication");
		}
		catch (IncompatibleMatrixException expected) {
			//  Expected: the operands cannot be multiplied
		}
	}

	/**
	 * Tests {@code mcfall.math.Matrix.postmultiply(Matrix)}.
	 * {@code matrix.postmultiply(multiplierMatrix)} must have matrix's row count
	 * and multiplierMatrix's column count. It must keep matrix's first indices
	 * and equal {@link #postMultiplyResult}. An operand whose row count differs
	 * from matrix's column count must raise {@link IncompatibleMatrixException}.
	 */
	public void testPostmultiply() {
		try {
			mcfall.math.Matrix result = matrix.postmultiply(multiplierMatrix);
			assertEquals (matrix.getNumberOfRows(), result.getNumberOfRows());
			assertEquals (multiplierMatrix.getNumberOfColumns(), result.getNumberOfColumns());
			assertEquals (matrix.getFirstColumnIndex(), result.getFirstColumnIndex());
			assertEquals (matrix.getFirstRowIndex(), result.getFirstRowIndex());
			result.setFirstColumnIndex(0);
			result.setFirstRowIndex(0);
			for (int r = 0; r < result.getNumberOfRows(); r++) {
				for (int c = 0; c < result.getNumberOfColumns(); c++) {
					assertEquals (postMultiplyResult[r][c], result.getValueAt(r,c));
				}
			}
		}
		catch (IncompatibleMatrixException invalidException) {
			fail ("Invalid IncompatibleMatrixException thrown when post-multiplying compatible matrices: " + invalidException.getMessage());
		}
		
		//  Now test to ensure that an appropriate incompatible matrix exception is thrown:
		//  the operand's row count (columns-1) never matches matrix's column count.
		mcfall.math.Matrix incompatible = new mcfall.math.Matrix (matrix.getNumberOfColumns()-1, matrix.getNumberOfRows());
		try {
			matrix.postmultiply (incompatible);
			fail ("Failed to throw an incompatible matrix exception when postmultiplying two incompatible matrices");
		}
		catch (IncompatibleMatrixException expected) {
			//  Expected: the operands cannot be multiplied
		}
	}

	/**
	 * Tests {@code mcfall.math.Matrix.toArray()}: the result is numRows x numColumns,
	 * holds the matrix values in row-major order starting at the first indices,
	 * and is an independent copy (writes to either side do not affect the other).
	 */
	public void testToArray () {
		double[][] results = matrix.toArray();
		
		//  Make sure the dimensions and values are the same
		assertEquals (numRows, results.length);
		for (int r = 0; r < results.length; r++) {
			assertEquals (numColumns, results[r].length);
			for (int c = 0; c < results[r].length; c++) {
				assertEquals (matrix.getValueAt(matrix.getFirstRowIndex()+r, matrix.getFirstColumnIndex()+c), results[r][c]);
			}
		}
		
		//  Make sure changes to results don't affect the original data values
		results[0][0] = -1.0;
		assertEquals (initialValues[0][0], matrix.getValueAt(matrix.getFirstRowIndex(), matrix.getFirstColumnIndex()));
		
		//  Make sure changes to the original data values don't change results
		matrix.setValueAt(matrix.getFirstRowIndex(), matrix.getFirstColumnIndex(), 0);
		assertEquals (-1.0, results[0][0]);
	}
	
	/**
	 * Tests {@code mcfall.math.Matrix.scalarMultiply(double)}: the receiver is
	 * unchanged, and the result has the same index bounds with every element scaled.
	 */
	public void testScalarMultiply() {
		double scalar = 2.0;
		matrix.setFirstColumnIndex(1);
		matrix.setFirstRowIndex(1);
		mcfall.math.Matrix result = matrix.scalarMultiply(scalar);

		//  Make sure that the original array values have not changed
		for (int row = 1; row <= numRows; row++) {
			for (int col = 1; col <= numColumns; col++) {
				assertEquals ("Original values unchanged by scalarMultiply", initialValues[row-1][col-1], matrix.getValueAt(row, col));
			}
		}
		
		//  Make sure the index bounds on the return matrix are correct
		assertEquals (matrix.getFirstRowIndex(), result.getFirstRowIndex());
		assertEquals (matrix.getFirstColumnIndex(), result.getFirstColumnIndex());
		assertEquals (matrix.getLastRowIndex(), result.getLastRowIndex());
		assertEquals (matrix.getLastColumnIndex(), result.getLastColumnIndex());
		
		//  And make sure that all of the values are correct
		for (int row=1; row <= numRows; row++) {
			for (int col = 1; col <= numColumns; col++) {
				assertEquals (scalar*initialValues[row-1][col-1], result.getValueAt(row, col));
			}
		}
	}
	
	/**
	 * Tests {@code mcfall.math.Matrix.createRotationMatrix} against two known
	 * 4x4 results (angles given in degrees, axes as 4-element column vectors
	 * with a 0 fourth component).
	 */
	public void testRotateAroundAxis () throws Exception {
		//  First is a 90 degree rotation around the z axis
		mcfall.math.Matrix rotationMatrix = mcfall.math.Matrix.createRotationMatrix(90, new ColumnVector (4, new double[] {0, 0, 1, 0}));
		assertAlmostEqualsElementwise (new double[][] {
			{0, -1, 0, 0},
			{1,  0, 0, 0},
			{0,  0, 1, 0},
			{0,  0, 0, 1}
		}, rotationMatrix);
		
		//  Example 5.3.4 out of the Hill textbook, page 241: 45 degrees around (0.577, 0.577, 0.577)
		rotationMatrix = mcfall.math.Matrix.createRotationMatrix(45, new ColumnVector(4, new double[] {0.577, 0.577, 0.577, 0}));
		assertAlmostEqualsElementwise (new double[][] {
			{0.8047, -0.31,   0.5058, 0},
			{0.5058,  0.8047, -0.31,  0},
			{-0.31,   0.5058, 0.8047, 0},
			{0,       0,      0,      1}
		}, rotationMatrix);
	}
	
	/**
	 * Checks, via the inherited {@code assertAlmostEquals}, that element
	 * {@code expected[r][c]} matches {@code actual} at 1-based position (r+1, c+1).
	 *
	 * @param expected expected values in row-major order
	 * @param actual   matrix under test, assumed to use 1-based indexing
	 */
	private void assertAlmostEqualsElementwise (double[][] expected, mcfall.math.Matrix actual) {
		for (int r = 0; r < expected.length; r++) {
			for (int c = 0; c < expected[r].length; c++) {
				assertAlmostEquals (expected[r][c], actual.getValueAt(r+1, c+1));
			}
		}
	}

}
